package quickml.supervised.parametricModels;

import com.google.common.collect.Lists;
import it.unimi.dsi.fastutil.ints.Int2DoubleMap;
import it.unimi.dsi.fastutil.ints.Int2DoubleOpenHashMap;
import org.javatuples.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import quickml.data.instances.SparseRegressionInstance;

import java.io.Serializable;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import static quickml.MathUtils.cappedlogBase2;
import static quickml.MathUtils.sigmoid;

/**
 * Created by alexanderhawk on 4/1/16.
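 *
 * <p>Logistic-regression cost function and gradient for {@link SparseRegressionInstance}s,
 * with optional ridge and lasso regularization. Gradient computation over a minibatch is
 * split across a fixed thread pool: each worker computes the contribution of a contiguous
 * block of instances, and those contributions are then reduced into the shared gradient
 * array. A minimal configuration sketch (values are illustrative, not recommendations):
 * <pre>{@code
 * ParallelizedLogisticDerivative derivative = new ParallelizedLogisticDerivative()
 *         .executorThreadCount(4)
 *         .minInstancesForParallelization(1000)
 *         .ridgeRegularizationConstant(1e-2)
 *         .lassoRegularizationConstant(1e-3)
 *         .maxGradientNorm(10.0);
 * }</pre>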
 */
public class ParallelizedLogisticDerivative implements OptimizableCostFunction<SparseRegressionInstance> {

    private int executorThreadCount = Runtime.getRuntime().availableProcessors();
    private ExecutorService executorService = Executors.newFixedThreadPool(executorThreadCount);


    public static final String EXPECTED_FRACTION_OF_FEATURES_TO_UPDATE_PER_WORKER = "expectedFractionOfFeaturesToUpdatePerWorker";
    public static final String EXECUTOR_THREAD_COUNT = "executorThreadCount";
    public static final String MIN_INSTANCES_FOR_PARALLELIZATION = "minInstancesForParallelization";
    public static final String SPARSE_PARALLELIZATION = "sparseParallelization";
    public static final String OPTIMIZABLE_COST_FUNCTION = "optimizableCostFunction";
    public static final String MAX_GRADIENT_NORM = "maxGradientNorm";
    public static final String RIDGE = "ridge";
    public static final String LASSO = "lasso";

    double ridge = 0;
    double lasso = 0;
    private double expectedFractionOfFeaturesToUpdatePerWorker = 1.0;
    private int minInstancesForParallelization = 100;
    private boolean sparseParallelization = true;
    private double maxGradientNorm = Double.MAX_VALUE;


    private static final Logger logger = LoggerFactory.getLogger(ParallelizedLogisticDerivative.class);

    public ParallelizedLogisticDerivative executorThreadCount(int executorThreadCount) {
        if (executorThreadCount <= Runtime.getRuntime().availableProcessors()) {
            this.executorThreadCount = executorThreadCount;
        } else {
            logger.warn("can't use more executor threads than cores; keeping {}", this.executorThreadCount);
        }
        return this;
    }
    public ParallelizedLogisticDerivative maxGradientNorm(double maxGradientNorm) {
        this.maxGradientNorm = maxGradientNorm;
        return this;
    }
    public ParallelizedLogisticDerivative minInstancesForParallelization(int minInstancesForParallelization) {
        this.minInstancesForParallelization = minInstancesForParallelization;
        return this;
    }

    public ParallelizedLogisticDerivative expectedFractionOfFeaturesToUpdatePerWorker(double expectedFractionOfFeaturesToUpdatePerWorker) {
        this.expectedFractionOfFeaturesToUpdatePerWorker = expectedFractionOfFeaturesToUpdatePerWorker;
        return this;
    }

    public ParallelizedLogisticDerivative sparseParallelization(boolean sparseParallelization) {
        this.sparseParallelization = sparseParallelization;
        return this;
    }

    public ParallelizedLogisticDerivative ridgeRegularizationConstant(final double ridgeRegularizationConstant) {
        this.ridge = ridgeRegularizationConstant;
        return this;
    }

    public ParallelizedLogisticDerivative lassoRegularizationConstant(final double lassoRegularizationConstant) {
        this.lasso = lassoRegularizationConstant;
        return this;
    }

    public static double probabilityOfTheNegativeClass(double[] weights, SparseRegressionInstance instance) {
        return 1.0 - probabilityOfThePositiveClass(weights, instance);
    }

    public static double probabilityOfThePositiveClass(double[] weights, SparseRegressionInstance instance) {
        return sigmoid(instance.dotProduct(weights));
    }

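    /**
     * Applies any recognized settings from a builder's configuration map; keys that are not
     * present leave the corresponding setting unchanged.
     */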
    public void updateBuilderConfig(final Map<String, Serializable> config) {

        if (config.containsKey(SPARSE_PARALLELIZATION)) {
            sparseParallelization((Boolean) config.get(SPARSE_PARALLELIZATION));
        }
        if (config.containsKey(MIN_INSTANCES_FOR_PARALLELIZATION)) {
            minInstancesForParallelization((Integer) config.get(MIN_INSTANCES_FOR_PARALLELIZATION));
        }

        if (config.containsKey(EXECUTOR_THREAD_COUNT)) {
            executorThreadCount((Integer) config.get(EXECUTOR_THREAD_COUNT));
        }
        if (config.containsKey(EXPECTED_FRACTION_OF_FEATURES_TO_UPDATE_PER_WORKER)) {
            expectedFractionOfFeaturesToUpdatePerWorker((Double) config.get(EXPECTED_FRACTION_OF_FEATURES_TO_UPDATE_PER_WORKER));
        }
        if (config.containsKey(MAX_GRADIENT_NORM)) {
            maxGradientNorm((Double) config.get(MAX_GRADIENT_NORM));
        }
        if (config.containsKey(LASSO)) {
            lassoRegularizationConstant((Double) config.get(LASSO));
        }
        if (config.containsKey(RIDGE)) {
            ridgeRegularizationConstant((Double) config.get(RIDGE));
        }

    }

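    /**
     * Returns the cross-entropy (log base 2) of the predicted class probabilities, with
     * predictions floored at minPredictedProbability, plus the regularization cost, averaged
     * over the instances.
     */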
    @Override
    public double computeCost(List<? extends SparseRegressionInstance> instances, double[] weights, double minPredictedProbability) {
        double cost = 0.0;
        for (SparseRegressionInstance instance : instances) {
            if ((double) instance.getLabel() == 1.0) {
                cost += -cappedlogBase2(probabilityOfThePositiveClass(weights, instance), minPredictedProbability);
            } else if ((double) instance.getLabel() == 0.0) {
                cost += -cappedlogBase2(probabilityOfTheNegativeClass(weights, instance), minPredictedProbability);
            }
        }
        cost += getRegularizationCost(weights, ridge, lasso);
        cost /= instances.size();
        return cost;
    }

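    /**
     * Accumulates the minibatch's logistic-loss gradient into {@code gradient}: the instances
     * are split into contiguous blocks, one per worker thread, the workers' contributions are
     * reduced into the gradient, the regularization component is added, and the result is
     * normalized by the minibatch size (optionally capping the gradient norm).
     */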
    @Override
    public void updateGradient(final List<? extends SparseRegressionInstance> sparseClassifierInstances, final double[] fixedWeights, final double[] gradient) {
        int currentMiniBatchSize = sparseClassifierInstances.size();
        final int[] threadStartAndStopIndices = getThreadStartIndices(currentMiniBatchSize, executorThreadCount, minInstancesForParallelization);
        int actualNumThreads = threadStartAndStopIndices.length - 1;
        if (sparseParallelization) {
            sparseCalculationOfGradient(sparseClassifierInstances, fixedWeights, gradient, threadStartAndStopIndices, actualNumThreads);
        } else {
            nonSparseCalculationOfGradient(sparseClassifierInstances, fixedWeights, gradient, threadStartAndStopIndices, actualNumThreads);

        }
        addRegularizationComponentOfTheGradient(fixedWeights, gradient, ridge, lasso);
        normalizeTheGradient(currentMiniBatchSize, maxGradientNorm, gradient);
    }



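    //each worker returns a sparse map from feature index to gradient contribution, so only
    //features that actually occur in the worker's block of instances are touched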
    private void sparseCalculationOfGradient(final List<? extends SparseRegressionInstance> sparseClassifierInstances, final double[] fixedWeights, double[] gradient, final int[] threadStartAndStopIndices, int actualNumThreads) {
        List<Future<Int2DoubleOpenHashMap>> contributionsToTheGradient = Lists.newArrayListWithCapacity(actualNumThreads);
        for (int i = 0; i < actualNumThreads; i++) {
            final int index = i;
            contributionsToTheGradient.add(executorService.submit(new Callable<Int2DoubleOpenHashMap>() {
                @Override
                public Int2DoubleOpenHashMap call() throws Exception {
                    try {
                        return getSparseWorkerContributionToTheGradient(sparseClassifierInstances.subList(threadStartAndStopIndices[index], threadStartAndStopIndices[index + 1]), fixedWeights, expectedFractionOfFeaturesToUpdatePerWorker);
                    } catch (IllegalArgumentException e) {
                        logger.error("failed to compute a worker's sparse contribution to the gradient", e);
                        throw new RuntimeException(e);
                    }
                }
            }));
        }
        sparseReductionToTheGradient(gradient, contributionsToTheGradient);
    }

    private void nonSparseCalculationOfGradient(final List<? extends SparseRegressionInstance> sparseClassifierInstances, final double[] fixedWeights, double[] gradient, final int[] threadStartAndStopIndices, int actualNumThreads) {
        List<Future<double[]>> contributionsToTheGradient = Lists.newArrayListWithCapacity(actualNumThreads);
        for (int i = 0; i < actualNumThreads; i++) {
            final int index = i;
            contributionsToTheGradient.add(executorService.submit(new Callable<double[]>() {
                @Override
                public double[] call() throws Exception {
                    return getWorkerContributionToTheGradient(sparseClassifierInstances.subList(threadStartAndStopIndices[index], threadStartAndStopIndices[index + 1]), fixedWeights);
                }
            }));
        }
        reductionToTheGradient(gradient, contributionsToTheGradient);
    }



    static void sparseUpdateUnnormalizedGradientForInstance(double[] weights, Int2DoubleOpenHashMap contributionsToTheGradient,
                                                            SparseRegressionInstance instance) {
        double positiveClassProbability = probabilityOfThePositiveClass(weights, instance);
        Pair<int[], double[]> sparseAttributes = instance.getSparseAttributes();
        int[] indices = sparseAttributes.getValue0();
        double[] values = sparseAttributes.getValue1();
        for (int i = 0; i < indices.length; i++) {
            int featureIndex = indices[i];
            contributionsToTheGradient.addTo(featureIndex, gradientContributionOfAFeatureValue((Double) instance.getLabel(), positiveClassProbability, values[i]));
        }
    }

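    //derivative of the logistic loss for one feature value: -(y - p) * x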
    private static double gradientContributionOfAFeatureValue(double label, double positiveClassProbability, double value) {
        return -(label - positiveClassProbability) * value;
    }

    static void updateUnnormalizedGradientForInstance(double[] weights, double[] contributionsToTheGradient,
                                                      SparseRegressionInstance instance) {
        double positiveClassProbability = probabilityOfThePositiveClass(weights, instance);
        Pair<int[], double[]> sparseAttributes = instance.getSparseAttributes();
        int[] indices = sparseAttributes.getValue0();
        double[] values = sparseAttributes.getValue1();
        for (int i = 0; i < indices.length; i++) {
            int featureIndex = indices[i];
            contributionsToTheGradient[featureIndex] += gradientContributionOfAFeatureValue((Double) instance.getLabel(), positiveClassProbability, values[i]);
        }
    }

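    //expectedFractionOfFeaturesToUpdate only sizes the hash map; an underestimate merely costs rehashing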
    public Int2DoubleOpenHashMap getSparseWorkerContributionToTheGradient(List<? extends SparseRegressionInstance> instances, double[] weights, double expectedFractionOfFeaturesToUpdate) {
        Int2DoubleOpenHashMap contributionsToTheGradient = new Int2DoubleOpenHashMap((int) (expectedFractionOfFeaturesToUpdate * weights.length));
        contributionsToTheGradient.defaultReturnValue(0.0);
        for (SparseRegressionInstance instance : instances) {
            sparseUpdateUnnormalizedGradientForInstance(weights, contributionsToTheGradient, instance);
        }
        return contributionsToTheGradient;
    }
    public static void sparseReductionToTheGradient(double[] gradient, List<Future<Int2DoubleOpenHashMap>> contributions) {
        for (Future<Int2DoubleOpenHashMap> contribution : contributions) {
            addSparseContribution(gradient, contribution);
        }
    }

    public static void addSparseContribution(double[] gradient, Future<Int2DoubleOpenHashMap> contributionFuture) {
        try {
            Int2DoubleOpenHashMap contribution = contributionFuture.get();
            for (Int2DoubleMap.Entry entry : contribution.int2DoubleEntrySet()) {
                gradient[entry.getIntKey()] += entry.getDoubleValue();
            }
        } catch (Exception e) {
            logger.error("failed to collect a worker's sparse gradient contribution", e);
            throw new RuntimeException(e);
        }
    }

    public static double[] getWorkerContributionToTheGradient(List<? extends SparseRegressionInstance> instances, double[] weights) {
        double[] contributionsToTheGradient = new double[weights.length];
        for (SparseRegressionInstance instance : instances) {
            updateUnnormalizedGradientForInstance(weights, contributionsToTheGradient, instance);
        }
        return contributionsToTheGradient;
    }
    public static void reductionToTheGradient(double[] gradient, List<Future<double[]>> contributions) {
        for (Future<double[]> contribution : contributions) {
            addContribution(gradient, contribution);
        }
    }



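    //adds the derivative of (ridge * w^2 / 2 + lasso * |w|) for each weight, skipping the bias;
    //the lasso subgradient uses the sign of the weight, treating w == 0 as positive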
    public static void addRegularizationComponentOfTheGradient(double[] weights, double[] gradient, double ridge, double lasso) {
        for (int i = 1; i < weights.length; i++) {//start at 1 to skip the bias term
            double lassoDerivative = lasso;
            if (weights[i] < 0.0) {
                lassoDerivative *= -1;
            }
            gradient[i] += ridge * weights[i] + lassoDerivative;
        }
    }


    public static void addContribution(double[] gradient, Future<double[]> contributionFuture) {
        try {
            double[] contribution = contributionFuture.get();
            for (int i = 0; i < gradient.length; i++) {
                gradient[i] += contribution[i];
            }
        } catch (Exception e) {
            logger.error("failed to collect a worker's gradient contribution", e);
            throw new RuntimeException(e);
        }
    }


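    //ridge contributes w^2 * ridge / 2 and lasso contributes |w| * lasso per non-bias weight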
    public static double getRegularizationCost(double[] weights, double ridge, double lasso) {
        double cost = 0;
        for (int i = 1; i < weights.length; i++) {//start at 1 to skip the bias term, matching the gradient
            cost += weights[i] * weights[i] * ridge / 2.0 + Math.abs(weights[i]) * lasso;
        }
        return cost;
    }


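    //converts the summed per-instance contributions into an average over the minibatch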
    public static void normalizeTheGradient(int minibatchSize, double maxGradientNorm, double[] gradient) {
        for (int i = 0; i < gradient.length; i++) {//include the bias term in the minibatch average
            gradient[i] /= minibatchSize;
        }
        if (maxGradientNorm != Double.MAX_VALUE) {
            applyMaxGradientNorm(maxGradientNorm, gradient);
        }
    }




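    //standard gradient clipping: if the L2 norm exceeds maxGradientNorm, rescale the
    //gradient so its norm equals maxGradientNorm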
    public static void applyMaxGradientNorm(double maxGradientNorm, double[] gradient) {
        double gradientSumOfSquares = 0;
        for (double g : gradient) {
            gradientSumOfSquares += g * g;
        }
        double gradientNorm = Math.sqrt(gradientSumOfSquares);
        if (gradientNorm > maxGradientNorm) {
            double scale = gradientNorm / maxGradientNorm;//rescale so the resulting norm is exactly maxGradientNorm
            for (int i = 0; i < gradient.length; i++) {
                gradient[i] /= scale;
            }
        }
    }

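    /**
     * Partitions [0, numInstances) into contiguous chunks, one per worker; chunk i covers
     * [result[i], result[i + 1]).
     */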
    public int[] getThreadStartIndices(int numInstances, int numThreads, int minInstancesForParallelization) {
        if (numInstances < minInstancesForParallelization) {
            //too few instances to be worth parallelizing: a single chunk covering everything
            return new int[]{0, numInstances};
        } else if (numInstances <= numThreads) {
            //no more instances than threads: one instance per chunk
            int[] threadStartIndices = new int[numInstances + 1];
            for (int i = 0; i <= numInstances; i++) {
                threadStartIndices[i] = i;
            }
            return threadStartIndices;
        }
        //split the instances as evenly as possible, assigning the remainder to the last threads
        int[] threadStartIndices = new int[numThreads + 1];
        int lowerSamplesPerThread = numInstances / numThreads;
        int upperSamplesPerThread = lowerSamplesPerThread + 1;
        int remainder = numInstances % numThreads;
        int currentStartIndex = 0;
        for (int i = 0; i < numThreads; i++) {
            threadStartIndices[i] = currentStartIndex;
            if (i >= numThreads - remainder) {
                currentStartIndex += upperSamplesPerThread;
            } else {
                currentStartIndex += lowerSamplesPerThread;
            }
        }
        threadStartIndices[numThreads] = numInstances;//by convention, the final stop index is set outside the loop
        return threadStartIndices;
    }

    @Override
    public void shutdown(){
        executorService.shutdown();
    }

}
